#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import click
import glob
import pickle
import numpy as np
from parse_mjl import parse_mjl_logs, viz_parsed_mjl_logs
import adept_envs
import time as timer
import skvideo.io
import gym
import cv2
from PIL import Image
import os

render_buffer = []  # frames captured by viewer() in 'offscreen' mode (BGR arrays)


def _mp4_filename(filename):
    """Return the path of the MP4 companion file for the AVI at *filename*.

    Only the final extension is replaced, so an ``.avi`` inside a directory
    name is left alone, and a path without ``.avi`` still gets a distinct
    ``.mp4`` name instead of colliding with *filename* itself.
    """
    root, ext = os.path.splitext(filename)
    if ext.lower() == '.mp4':
        # Never hand both VideoWriters the same output path.
        return root + '_mp4v.mp4'
    return root + '.mp4'


def viewer(env,
           mode='initialize',
           filename='video',
           video_resolution=(1920, 2560),
           camera_id=0,
           render=None,
           fps=30.0):
    """Show, buffer, or save frames rendered from ``env``.

    Args:
        env: environment; ``env.mj_render()`` is called for 'onscreen' and
            ``env.render(mode='rgb_array')`` for 'offscreen'.
        mode: offscreen action. 'initialize' clears the module-level
            ``render_buffer`` and then captures a frame, 'render' captures a
            frame, 'save' writes the buffered frames to disk.
        filename: output path of the FFV1-encoded AVI. An mp4v-encoded copy
            is written next to it with the extension replaced by '.mp4'.
        video_resolution: currently unused; kept for interface compatibility.
        camera_id: currently unused; kept for interface compatibility.
        render: 'onscreen', 'offscreen', or None / 'None' to do nothing.
            Any other value prints a warning.
        fps: frame rate stored in the saved videos.
    """
    global render_buffer

    if render is None or render == 'None':
        return
    if render == 'onscreen':
        env.mj_render()
        return
    if render != 'offscreen':
        print("unknown render: ", render)
        return

    if mode == 'initialize':
        render_buffer = []
        mode = 'render'

    if mode == 'render':
        curr_frame = env.render(mode='rgb_array')
        # Round-trip through PIL (kept from the original pipeline); the
        # final np.array yields a fresh array handed to OpenCV.
        rgb_array = np.array(Image.fromarray(np.array(curr_frame)))
        render_buffer.append(cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR))
    elif mode == 'save':
        # The most recent frame is dropped before saving, as before.
        # NOTE(review): the reason for dropping it is not visible here.
        if render_buffer:
            render_buffer.pop()
        if not render_buffer:
            print(f"No frames buffered; skipping save of {filename}")
            return

        height, width = render_buffer[0].shape[:2]
        mp4_filename = _mp4_filename(filename)
        video_out_avi = cv2.VideoWriter(
            filename, cv2.VideoWriter_fourcc(*'FFV1'), fps, (width, height))
        video_out_mp4 = cv2.VideoWriter(
            mp4_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        try:
            for frame in render_buffer:
                video_out_avi.write(frame)
                video_out_mp4.write(frame)
        finally:
            # Always flush/close the containers, even if a write fails.
            video_out_avi.release()
            video_out_mp4.release()
        print(f"Saved AVI video: {filename}")
        print(f"Saved MP4 video: {mp4_filename}")


# view demos (physics ignored)
def render_demos(env, data, filename='demo_rendering.avi', render=None):
    """Kinematically replay a recorded demo and render it.

    Each recorded ``qpos``/``qvel`` frame is written directly into the
    simulator state followed by ``sim.forward()``; the dynamics are never
    stepped. A frame is handed to ``viewer`` every ``render_skip`` steps,
    chosen so playback is about 30 frames per second, and the buffered
    frames are saved to ``filename`` at the end.
    """
    fps = 30
    render_skip = max(
        1, round(1. / (fps * env.sim.model.opt.timestep * env.frame_skip)))
    start = timer.time()

    viewer(env, mode='initialize', render=render)
    n_frames = data['ctrl'].shape[0]
    for step in range(n_frames):
        qpos_t = data['qpos'][step].copy()
        qvel_t = data['qvel'][step].copy()
        env.sim.data.qpos[:] = qpos_t
        env.sim.data.qvel[:] = qvel_t
        env.sim.forward()
        if step % render_skip:
            continue
        viewer(env, mode='render', render=render)

    viewer(env, mode='save', filename=filename, render=render)
    elapsed = timer.time() - start
    print("time taken = %f" % elapsed)


# playback demos and get data (physics respected)
def gather_training_data(env, data, filename='demo_playback.avi', render=None):
    """Replay a demo through the physics and record observation/action pairs.

    The simulator is reset and then placed at the first recorded
    ``qpos``/``qvel``. For every following frame, the recorded ``ctrl`` is
    turned into a normalised action, clipped to [-0.999, 0.999], and applied
    with ``env.step``, so the motion comes from the simulation rather than
    from the log.

    Args:
        env: environment whose ``.env`` attribute (the wrapped env) is used.
        data: parsed log providing per-frame 'qpos', 'qvel' and 'ctrl'.
        filename: output path passed to ``viewer`` when saving the video.
        render: forwarded to ``viewer``; if falsy, the save step is skipped.

    Returns:
        (path_obs, path_act, init_qpos, init_qvel), where ``path_obs`` and
        ``path_act`` stack one row per replayed step (always 2-D).

    Raises:
        ValueError: if ``data`` has fewer than two frames to replay.
    """
    env = env.env
    n_steps = data['ctrl'].shape[0] - 1
    if n_steps < 1:
        # Previously this surfaced as an UnboundLocalError on `path_act`.
        raise ValueError(
            f"demo needs at least 2 frames for playback, got {n_steps + 1}")

    FPS = 30
    render_skip = max(1, round(1. / \
        (FPS * env.sim.model.opt.timestep * env.frame_skip)))
    t0 = timer.time()

    env.reset()
    init_qpos = data['qpos'][0].copy()
    init_qvel = data['qvel'][0].copy()
    act_mid = env.act_mid
    act_rng = env.act_amp
    # Loop-invariant divisor of the control computation below.
    ctrl_dt = env.skip * env.model.opt.timestep

    env.sim.data.qpos[:] = init_qpos
    env.sim.data.qvel[:] = init_qvel
    env.sim.forward()
    viewer(env, mode='initialize', render=render)

    # Collect rows in lists and stack once at the end: the previous
    # per-step np.vstack re-copied the whole history each step (O(n^2)).
    obs_rows = []
    act_rows = []
    for i_frame in range(n_steps):
        obs = env._get_obs()
        # NOTE(review): the slice assumes obs[:9] lines up with the 9 ctrl
        # channels of the log; confirm against the env's observation layout.
        ctrl = (data['ctrl'][i_frame] - obs[:9]) / ctrl_dt
        act = np.clip((ctrl - act_mid) / act_rng, -0.999, 0.999)
        # Copy before stepping so the stored row is the observation the
        # action was computed from.
        obs_rows.append(np.array(obs))
        act_rows.append(act)
        env.step(act)
        if i_frame % render_skip == 0:
            viewer(env, mode='render', render=render)
    if render:
        viewer(env, mode='save', filename=filename, render=render)

    print("time taken = %f" % (timer.time() - t0))
    return np.vstack(obs_rows), np.vstack(act_rows), init_qpos, init_qvel


# MAIN =========================================================
@click.command(help="parse tele-op demos")
@click.option('--env', '-e', type=str, help='gym env name', required=True)
@click.option(
    '--demo_dir',
    '-d',
    type=str,
    help='directory with tele-op logs',
    required=True)
@click.option(
    '--skip',
    '-s',
    type=int,
    help='number of frames to skip (1:no skip)',
    default=1)
@click.option('--graph', '-g', type=bool, help='plot logs', default=False)
@click.option('--save_logs', '-l', type=bool, help='save logs', default=False)
@click.option(
    '--view', '-v', type=str, help='render/playback', default='render')
@click.option(
    '--render', '-r', type=str, help='onscreen/offscreen', default='onscreen')
@click.option(
    '--save_dir',
    '-o',
    type=str,
    help='root directory for playback outputs',
    default="/home/sachit/Desktop/COD892/Data_Franka_Kitchen")
def main(env, demo_dir, skip, graph, save_logs, view, render, save_dir):
    """Process every ``*.mjl`` log found in the sub-directories of demo_dir.

    Each log is parsed with ``parse_mjl_logs``. It can optionally be plotted
    and/or pickled next to the log. Then it is either rendered kinematically
    (view='render') or played back through the physics (view='playback').
    In playback mode the recorded path is pickled to
    ``<save_dir>/<subdir#>.<file#>/data.pkl``, next to the rendered video.
    """
    # Sorted so the "<subdir#>.<file#>" numbering is reproducible; glob
    # returns entries in arbitrary, filesystem-dependent order. os.path.join
    # also makes a trailing '/' on demo_dir optional.
    sub_dirs = sorted(glob.glob(os.path.join(demo_dir, "*", "")))
    for index, sub_dir in enumerate(sub_dirs):
        print(f"Processing subdirectory: {sub_dir}")
        log_files = sorted(glob.glob(os.path.join(sub_dir, "*.mjl")))
        for ind, file in enumerate(log_files):
            output_dir = os.path.join(save_dir, f"{index + 1}.{ind + 1}")
            print("processing: " + file, end=': \n')
            try:
                # parse_mjl_logs sometimes fails when the struct size is very big.
                data = parse_mjl_logs(file, skip)
            except Exception as e:  # not bare: let Ctrl-C abort the batch
                print(f"failed to parse {file}: {e}")
                continue
            if graph:
                print("plotting: " + file)
                viz_parsed_mjl_logs(data)
            if save_logs:
                with open(file[:-4] + ".pkl", 'wb') as log_fh:
                    pickle.dump(data, log_fh)

            gym_env = gym.make(env)
            gym_env.seed(42)
            try:
                if view == 'render':
                    render_demos(
                        gym_env,
                        data,
                        filename=data['logName'][:-4] + '_demo_render.avi',
                        render=render)

                elif view == 'playback':
                    # Created only here: playback is the sole consumer.
                    os.makedirs(output_dir, exist_ok=True)
                    try:
                        obs, act, init_qpos, init_qvel = gather_training_data(
                            gym_env,
                            data,
                            filename=os.path.join(output_dir, "camera_2.avi"),
                            render=render)
                    except Exception as e:
                        print(e)
                        continue
                    path = {
                        'observations': obs,
                        'actions': act,
                        'goals': obs,
                        'init_qpos': init_qpos,
                        'init_qvel': init_qvel
                    }
                    path_save_path = os.path.join(output_dir, "data.pkl")
                    with open(path_save_path, 'wb') as path_fh:
                        pickle.dump(path, path_fh)
                    print(f"Saved playback path: {path_save_path}")
            finally:
                # One env is created per log; release it before the next one.
                gym_env.close()

# Script entry point: `main` is a click command, so calling it parses the
# command-line options declared on it and runs the processing loop.
if __name__ == '__main__':
    main()